// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package adapt

import (
	"encoding/base64"
	"fmt"
	"strings"

	storagepb "google.golang.org/genproto/googleapis/cloud/bigquery/storage/v1"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/reflect/protodesc"
	"google.golang.org/protobuf/reflect/protoreflect"
	"google.golang.org/protobuf/types/descriptorpb"
	"google.golang.org/protobuf/types/known/wrapperspb"
)

// Lookup tables translating a BigQuery field mode into a proto field label.
// There is one table per proto syntax; convertModeToLabel picks between them.
var (
	// bqModeToFieldLabelMapProto2 carries every BigQuery mode over to the
	// like-named proto2 label.
	bqModeToFieldLabelMapProto2 = map[storagepb.TableFieldSchema_Mode]descriptorpb.FieldDescriptorProto_Label{
		storagepb.TableFieldSchema_REQUIRED: descriptorpb.FieldDescriptorProto_LABEL_REQUIRED,
		storagepb.TableFieldSchema_NULLABLE: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL,
		storagepb.TableFieldSchema_REPEATED: descriptorpb.FieldDescriptorProto_LABEL_REPEATED,
	}

	// bqModeToFieldLabelMapProto3 differs from the proto2 table only for
	// REQUIRED, which is emitted with the optional label.
	bqModeToFieldLabelMapProto3 = map[storagepb.TableFieldSchema_Mode]descriptorpb.FieldDescriptorProto_Label{
		storagepb.TableFieldSchema_REQUIRED: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL,
		storagepb.TableFieldSchema_NULLABLE: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL,
		storagepb.TableFieldSchema_REPEATED: descriptorpb.FieldDescriptorProto_LABEL_REPEATED,
	}
)

// convertModeToLabel returns a pointer to the proto field label that corresponds
// to the given BigQuery field mode.  The proto3 lookup table is consulted when
// useProto3 is set, and the proto2 table otherwise.  A mode that is absent from
// the selected table yields a pointer to the zero-valued label.
func convertModeToLabel(mode storagepb.TableFieldSchema_Mode, useProto3 bool) *descriptorpb.FieldDescriptorProto_Label {
	labels := bqModeToFieldLabelMapProto2
	if useProto3 {
		labels = bqModeToFieldLabelMapProto3
	}
	return labels[mode].Enum()
}

// bqTypeToFieldTypeMap translates a BigQuery schema type into the
// FieldDescriptorProto type used to carry it.  Entries are grouped by the
// proto type they map to.
var bqTypeToFieldTypeMap = map[storagepb.TableFieldSchema_Type]descriptorpb.FieldDescriptorProto_Type{
	// Carried as bytes.
	storagepb.TableFieldSchema_BYTES:      descriptorpb.FieldDescriptorProto_TYPE_BYTES,
	storagepb.TableFieldSchema_NUMERIC:    descriptorpb.FieldDescriptorProto_TYPE_BYTES,
	storagepb.TableFieldSchema_BIGNUMERIC: descriptorpb.FieldDescriptorProto_TYPE_BYTES,

	// Carried as string.
	storagepb.TableFieldSchema_STRING:    descriptorpb.FieldDescriptorProto_TYPE_STRING,
	storagepb.TableFieldSchema_GEOGRAPHY: descriptorpb.FieldDescriptorProto_TYPE_STRING,

	// Carried as 64-bit integers.
	storagepb.TableFieldSchema_INT64:     descriptorpb.FieldDescriptorProto_TYPE_INT64,
	storagepb.TableFieldSchema_DATETIME:  descriptorpb.FieldDescriptorProto_TYPE_INT64,
	storagepb.TableFieldSchema_TIME:      descriptorpb.FieldDescriptorProto_TYPE_INT64,
	storagepb.TableFieldSchema_TIMESTAMP: descriptorpb.FieldDescriptorProto_TYPE_INT64,

	// Carried as a 32-bit integer.
	storagepb.TableFieldSchema_DATE: descriptorpb.FieldDescriptorProto_TYPE_INT32,

	// Direct scalar equivalents.
	storagepb.TableFieldSchema_BOOL:   descriptorpb.FieldDescriptorProto_TYPE_BOOL,
	storagepb.TableFieldSchema_DOUBLE: descriptorpb.FieldDescriptorProto_TYPE_DOUBLE,

	// Nested records become messages.
	storagepb.TableFieldSchema_STRUCT: descriptorpb.FieldDescriptorProto_TYPE_MESSAGE,
}

// bqTypeToWrapperMap names the well-known wrapper message used for each BigQuery
// type when a field has NULLABLE mode under proto3.  Wrapper types allow for the
// proper representation of NULL values, as proto3 semantics would otherwise just
// use the default value.  Entries are grouped by wrapper type.
var bqTypeToWrapperMap = map[storagepb.TableFieldSchema_Type]string{
	// BytesValue
	storagepb.TableFieldSchema_BYTES:      ".google.protobuf.BytesValue",
	storagepb.TableFieldSchema_NUMERIC:    ".google.protobuf.BytesValue",
	storagepb.TableFieldSchema_BIGNUMERIC: ".google.protobuf.BytesValue",

	// StringValue
	storagepb.TableFieldSchema_STRING:    ".google.protobuf.StringValue",
	storagepb.TableFieldSchema_GEOGRAPHY: ".google.protobuf.StringValue",

	// Int64Value
	storagepb.TableFieldSchema_INT64:     ".google.protobuf.Int64Value",
	storagepb.TableFieldSchema_DATETIME:  ".google.protobuf.Int64Value",
	storagepb.TableFieldSchema_TIME:      ".google.protobuf.Int64Value",
	storagepb.TableFieldSchema_TIMESTAMP: ".google.protobuf.Int64Value",

	// Int32Value
	storagepb.TableFieldSchema_DATE: ".google.protobuf.Int32Value",

	// BoolValue / DoubleValue
	storagepb.TableFieldSchema_BOOL:   ".google.protobuf.BoolValue",
	storagepb.TableFieldSchema_DOUBLE: ".google.protobuf.DoubleValue",
}

// wellKnownTypesWrapperName is the filename used by the well known wrapper types proto.
// It is listed as a dependency of every file descriptor generated by the schema conversion.
var wellKnownTypesWrapperName = "google/protobuf/wrappers.proto"

// dependencyCache is used to reduce the number of unique messages we generate by caching
// descriptors based on the table schema they were generated from.
//
// Keys are the base64-encoded serialized TableSchema value; values are the descriptors
// that were previously built for that schema.
type dependencyCache map[string]protoreflect.Descriptor

// get looks up the descriptor previously stored for schema.
//
// It returns nil when the cache itself is nil, when the schema cannot be
// serialized into a cache key, or when no entry exists for the schema.
func (dm dependencyCache) get(schema *storagepb.TableSchema) protoreflect.Descriptor {
	if dm == nil {
		return nil
	}
	serialized, marshalErr := proto.Marshal(schema)
	if marshalErr != nil {
		return nil
	}
	// Indexing with a missing key produces the nil descriptor, which is
	// exactly what a cache miss reports.
	return dm[base64.StdEncoding.EncodeToString(serialized)]
}

// add stores descriptor in the cache under the key derived from schema,
// replacing any entry already present for that key.
//
// An error is returned if the cache is nil or the schema cannot be serialized.
func (dm dependencyCache) add(schema *storagepb.TableSchema, descriptor protoreflect.Descriptor) error {
	if dm == nil {
		return fmt.Errorf("cache is nil")
	}
	serialized, marshalErr := proto.Marshal(schema)
	if marshalErr != nil {
		return fmt.Errorf("failed to serialize tableschema: %v", marshalErr)
	}
	key := base64.StdEncoding.EncodeToString(serialized)
	dm[key] = descriptor
	return nil
}

// StorageSchemaToProto2Descriptor converts the given table schema into a
// protoreflect.Descriptor expressed in proto2 syntax.  The scope argument is used
// as the name of the resulting top-level message.
func StorageSchemaToProto2Descriptor(inSchema *storagepb.TableSchema, scope string) (protoreflect.Descriptor, error) {
	// TODO: b/193064992 tracks support for wrapper types.  In the interim, disable wrapper usage.
	const useProto3 = false
	// Each conversion starts from an empty cache of already-built submessages.
	cache := dependencyCache{}
	return storageSchemaToDescriptorInternal(inSchema, scope, &cache, useProto3)
}

// StorageSchemaToProto3Descriptor converts the given table schema into a
// protoreflect.Descriptor expressed in proto3 syntax.  The scope argument is used
// as the name of the resulting top-level message.
//
// NOTE: Currently the write API doesn't yet support proto3 behaviors (default value,
// wrapper types, etc), but this is provided for completeness.
func StorageSchemaToProto3Descriptor(inSchema *storagepb.TableSchema, scope string) (protoreflect.Descriptor, error) {
	const useProto3 = true
	// Each conversion starts from an empty cache of already-built submessages.
	cache := dependencyCache{}
	return storageSchemaToDescriptorInternal(inSchema, scope, &cache, useProto3)
}

// storageSchemaToDescriptorInternal is the shared implementation behind the exported
// StorageSchemaTo*Descriptor functions.
//
// It produces a single message named scope, defined in a file named "<scope>.proto", with
// one field per entry of inSchema; field numbers start at 1 and follow schema order.
// Each STRUCT field is converted recursively into its own message (and file) named
// "<scope>__<field name>".  Submessages are shared through cache: when a STRUCT with an
// identical set of sub-fields has already been converted, the existing descriptor is
// reused rather than generating a new one.
//
// useProto3 selects the syntax of the generated file and is forwarded to the per-field
// conversion.  An error is returned when inSchema is nil, when a field cannot be
// converted, or when the resulting descriptors fail to load.
func storageSchemaToDescriptorInternal(inSchema *storagepb.TableSchema, scope string, cache *dependencyCache, useProto3 bool) (protoreflect.Descriptor, error) {
	if inSchema == nil {
		return nil, newConversionError(scope, fmt.Errorf("no input schema was provided"))
	}

	var fields []*descriptorpb.FieldDescriptorProto
	var deps []protoreflect.FileDescriptor
	var fNumber int32

	for _, f := range inSchema.GetFields() {
		fNumber = fNumber + 1
		currentScope := fmt.Sprintf("%s__%s", scope, f.GetName())
		// If we're dealing with a STRUCT type, we must deal with sub messages.
		// As multiple submessages may share the same type definition, we use a dependency cache
		// and interrogate it / populate it as we're going.
		if f.Type == storagepb.TableFieldSchema_STRUCT {
			foundDesc := cache.get(&storagepb.TableSchema{Fields: f.GetFields()})
			if foundDesc != nil {
				// Check to see if we already have this in the current dependency list.
				//
				// Files are compared by path: FullName() of a file descriptor is its package
				// name, and the files generated here declare no package, so it cannot tell
				// two generated files apart.
				foundFile := foundDesc.ParentFile()
				haveDep := false
				for _, curDep := range deps {
					if foundFile.Path() == curDep.Path() {
						haveDep = true
						break
					}
				}
				// if dep is missing, add to current dependencies
				if !haveDep {
					deps = append(deps, foundFile)
				}
				// construct field descriptor for the message
				fdp, err := tableFieldSchemaToFieldDescriptorProto(f, fNumber, string(foundDesc.FullName()), useProto3)
				if err != nil {
					return nil, newConversionError(scope, fmt.Errorf("couldn't convert field to FieldDescriptorProto: %v", err))
				}
				fields = append(fields, fdp)
			} else {
				// Wrap the current struct's fields in a TableSchema outer message, and then build the submessage.
				ts := &storagepb.TableSchema{
					Fields: f.GetFields(),
				}
				desc, err := storageSchemaToDescriptorInternal(ts, currentScope, cache, useProto3)
				if err != nil {
					return nil, newConversionError(currentScope, fmt.Errorf("couldn't convert message: %v", err))
				}
				// Now that we have the submessage definition, we append it both to the local dependencies, as well
				// as inserting it into the cache for possible reuse elsewhere.
				deps = append(deps, desc.ParentFile())
				err = cache.add(ts, desc)
				if err != nil {
					return nil, newConversionError(currentScope, fmt.Errorf("failed to add descriptor to dependency cache: %v", err))
				}
				fdp, err := tableFieldSchemaToFieldDescriptorProto(f, fNumber, currentScope, useProto3)
				if err != nil {
					return nil, newConversionError(currentScope, fmt.Errorf("couldn't compute field schema : %v", err))
				}
				fields = append(fields, fdp)
			}
		} else {
			fd, err := tableFieldSchemaToFieldDescriptorProto(f, fNumber, currentScope, useProto3)
			if err != nil {
				return nil, newConversionError(currentScope, err)
			}
			fields = append(fields, fd)
		}
	}
	// Start constructing a DescriptorProto.
	dp := &descriptorpb.DescriptorProto{
		Name:  proto.String(scope),
		Field: fields,
	}

	// Use the local dependencies to generate a list of filenames.  The wrapper types file
	// is always listed first.
	depNames := []string{
		wellKnownTypesWrapperName,
	}
	for _, d := range deps {
		depNames = append(depNames, d.Path())
	}

	// Now, construct a FileDescriptorProto.
	fileName := fmt.Sprintf("%s.proto", scope)
	syntax := "proto2"
	if useProto3 {
		syntax = "proto3"
	}
	fdp := &descriptorpb.FileDescriptorProto{
		MessageType: []*descriptorpb.DescriptorProto{dp},
		Name:        proto.String(fileName),
		Syntax:      proto.String(syntax),
		Dependency:  depNames,
	}

	// We'll need a FileDescriptorSet as we have a FileDescriptorProto for the current
	// descriptor we're building, but we need to include all the referenced dependencies.
	fds := &descriptorpb.FileDescriptorSet{
		File: []*descriptorpb.FileDescriptorProto{
			fdp,
			protodesc.ToFileDescriptorProto(wrapperspb.File_google_protobuf_wrappers_proto),
		},
	}
	// Dependency files import further files of their own (the wrapper types, and the files
	// of any STRUCTs nested inside them).  Those imports must be present in the set as well,
	// each exactly once, so track the paths that have already been added.
	seen := map[string]struct{}{
		fileName:                  {},
		wellKnownTypesWrapperName: {},
	}
	for _, d := range deps {
		fds.File = appendFileWithImports(fds.File, d, seen)
	}

	// Load the set into a registry, then interrogate it for the descriptor corresponding to the top level message.
	files, err := protodesc.NewFiles(fds)
	if err != nil {
		return nil, err
	}
	return files.FindDescriptorByName(protoreflect.FullName(scope))
}

// appendFileWithImports appends the FileDescriptorProto form of fd, followed by those of
// its transitive imports, to out and returns the extended slice.  Files whose path is
// already recorded in seen are skipped, and every file appended is recorded in seen.
func appendFileWithImports(out []*descriptorpb.FileDescriptorProto, fd protoreflect.FileDescriptor, seen map[string]struct{}) []*descriptorpb.FileDescriptorProto {
	if fd == nil || fd.IsPlaceholder() {
		return out
	}
	path := fd.Path()
	if _, ok := seen[path]; ok {
		return out
	}
	seen[path] = struct{}{}
	out = append(out, protodesc.ToFileDescriptorProto(fd))
	imports := fd.Imports()
	for i := 0; i < imports.Len(); i++ {
		out = appendFileWithImports(out, imports.Get(i).FileDescriptor, seen)
	}
	return out
}

// tableFieldSchemaToFieldDescriptorProto builds individual field descriptors for a proto message.
//
// The field name is the lower-cased BigQuery field name and idx is used as the field number.
//
//   - STRUCT fields reference the message named by scope; the label is derived from the mode.
//   - For proto2 (all modes), and for proto3 fields that are not NULLABLE, the scalar proto
//     type for the BigQuery type is used and the label is derived from the mode.
//   - For NULLABLE proto3 fields, the matching well known wrapper type is used with an
//     optional label.
//
// An error is returned when the BigQuery type has no known scalar or wrapper equivalent,
// rather than emitting a field descriptor with an unset type.
func tableFieldSchemaToFieldDescriptorProto(field *storagepb.TableFieldSchema, idx int32, scope string, useProto3 bool) (*descriptorpb.FieldDescriptorProto, error) {
	name := strings.ToLower(field.GetName())
	if field.GetType() == storagepb.TableFieldSchema_STRUCT {
		return &descriptorpb.FieldDescriptorProto{
			Name:     proto.String(name),
			Number:   proto.Int32(idx),
			TypeName: proto.String(scope),
			Label:    convertModeToLabel(field.GetMode(), useProto3),
		}, nil
	}

	// For (REQUIRED||REPEATED) fields for proto3, or all cases for proto2, we can use the expected scalar types.
	if field.GetMode() != storagepb.TableFieldSchema_NULLABLE || !useProto3 {
		scalarType, ok := bqTypeToFieldTypeMap[field.GetType()]
		if !ok {
			return nil, fmt.Errorf("field %q: no proto type mapping for BigQuery type %v", field.GetName(), field.GetType())
		}
		return &descriptorpb.FieldDescriptorProto{
			Name:   proto.String(name),
			Number: proto.Int32(idx),
			Type:   scalarType.Enum(),
			Label:  convertModeToLabel(field.GetMode(), useProto3),
		}, nil
	}

	// For NULLABLE proto3 fields, use a wrapper type.
	wrapperName, ok := bqTypeToWrapperMap[field.GetType()]
	if !ok {
		return nil, fmt.Errorf("field %q: no wrapper type mapping for BigQuery type %v", field.GetName(), field.GetType())
	}
	return &descriptorpb.FieldDescriptorProto{
		Name:     proto.String(name),
		Number:   proto.Int32(idx),
		Type:     descriptorpb.FieldDescriptorProto_TYPE_MESSAGE.Enum(),
		TypeName: proto.String(wrapperName),
		Label:    descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
	}, nil
}

// NormalizeDescriptor produces a self-contained DescriptorProto that can be used to
// communicate schema information to the BigQuery Storage write API.  Its main use is for
// callers who want to send data using a predefined protocol buffer message.
//
// The storage API accepts a single DescriptorProto for decoding message data, yet a message
// is frequently composed of several independent messages, defined in the same .proto file or
// spread across multiple sources.  Instead of requiring each of those to be communicated
// separately, this function rewrites the DescriptorProto so that all referenced messages are
// inlined as nested submessages.  Because the backend only cares about the types and not the
// namespaces when decoding, this is sufficient for the needs of the API's representation.
//
// Besides nesting messages, enum types are wrapped in an enclosing message to avoid possible
// conflicts due to ambiguities.
func NormalizeDescriptor(in protoreflect.MessageDescriptor) (*descriptorpb.DescriptorProto, error) {
	// Fresh bookkeeping sets for this traversal; a nil root marks the top-level message.
	visited, enums, structs := newStringSet(), newStringSet(), newStringSet()
	return normalizeDescriptorInternal(in, visited, enums, structs, nil)
}

// normalizeDescriptorInternal recursively rewrites the message descriptor in into a
// DescriptorProto whose name, and whose message-typed field references, use normalized
// (dot-free) names.
//
// The three sets are shared across the whole traversal:
//   - visitedTypes holds the full names of the messages currently being converted, and is
//     used to reject recursive message definitions.
//   - enumTypes holds the full names of enums that have already been emitted.
//   - structTypes holds the full names of messages whose conversion has completed, so a
//     later reference reuses the existing definition instead of emitting it again.
//
// root is the top-level DescriptorProto that collects converted submessages as nested
// types; pass nil when converting the top-level message itself.
//
// An error is returned when in is nil, when a recursive message reference is found, or
// when converting a referenced message fails.
func normalizeDescriptorInternal(in protoreflect.MessageDescriptor, visitedTypes, enumTypes, structTypes *stringSet, root *descriptorpb.DescriptorProto) (*descriptorpb.DescriptorProto, error) {
	if in == nil {
		return nil, fmt.Errorf("no messagedescriptor provided")
	}
	resultDP := &descriptorpb.DescriptorProto{}
	if root == nil {
		// This is the outermost call: the message being built is the root.
		root = resultDP
	}
	fullProtoName := string(in.FullName())
	resultDP.Name = proto.String(normalizeName(fullProtoName))
	visitedTypes.add(fullProtoName)

	inFields := in.Fields()
	for i := 0; i < inFields.Len(); i++ {
		inField := inFields.Get(i)
		resultFDP := protodesc.ToFieldDescriptorProto(inField)
		if inField.Kind() == protoreflect.MessageKind || inField.Kind() == protoreflect.GroupKind {
			// Handle fields that reference messages.
			// Groups are a proto2-ism which predated nested messages.
			msgFullName := string(inField.Message().FullName())
			if !skipNormalization(msgFullName) {
				// for everything but well known types, normalize.
				if structTypes.contains(msgFullName) {
					// Already converted elsewhere; just point at the existing definition.
					resultFDP.TypeName = proto.String(normalizeName(msgFullName))
				} else {
					if visitedTypes.contains(msgFullName) {
						// The referenced message is still being converted further up the
						// call chain, so the definition refers back to itself.
						return nil, fmt.Errorf("recursive type not supported: %s", inField.FullName())
					}
					visitedTypes.add(msgFullName)
					dp, err := normalizeDescriptorInternal(inField.Message(), visitedTypes, enumTypes, structTypes, root)
					if err != nil {
						return nil, fmt.Errorf("error converting message %s: %v", inField.FullName(), err)
					}
					// Converted submessages are always attached to the root message.
					root.NestedType = append(root.NestedType, dp)
					visitedTypes.delete(msgFullName)
					resultFDP.TypeName = proto.String(dp.GetName())
				}
			}
		}
		if inField.Kind() == protoreflect.EnumKind {
			// For enums, in order to avoid value conflict, we will always define
			// an enclosing struct called enum_full_name_E that includes the actual
			// enum.
			enumFullName := string(inField.Enum().FullName())
			enclosingTypeName := normalizeName(enumFullName) + "_E"
			enumName := string(inField.Enum().Name())
			actualFullName := fmt.Sprintf("%s.%s", enclosingTypeName, enumName)
			if !enumTypes.contains(enumFullName) {
				// First use of this enum: emit the enclosing message holding the enum
				// as a nested type of the message currently being built.
				enumDP := protodesc.ToEnumDescriptorProto(inField.Enum())
				enumDP.Name = proto.String(enumName)
				resultDP.NestedType = append(resultDP.NestedType, &descriptorpb.DescriptorProto{
					Name:     proto.String(enclosingTypeName),
					EnumType: []*descriptorpb.EnumDescriptorProto{enumDP},
				})
				enumTypes.add(enumFullName)
			}
			resultFDP.TypeName = proto.String(actualFullName)
		}
		resultDP.Field = append(resultDP.Field, resultFDP)
	}
	structTypes.add(fullProtoName)
	return resultDP, nil
}

// stringSet is a minimal set of strings backed by a map.  Instances must be
// created with newStringSet so the backing map is initialized.
type stringSet struct {
	m map[string]struct{}
}

// contains reports whether k is a member of the set.
func (s *stringSet) contains(k string) bool {
	if _, present := s.m[k]; present {
		return true
	}
	return false
}

// add inserts k into the set; inserting an existing member has no effect.
func (s *stringSet) add(k string) {
	var member struct{}
	s.m[k] = member
}

// delete removes k from the set; removing a non-member has no effect.
func (s *stringSet) delete(k string) {
	delete(s.m, k)
}

// newStringSet returns an empty, ready-to-use stringSet.
func newStringSet() *stringSet {
	set := new(stringSet)
	set.m = map[string]struct{}{}
	return set
}

// normalizeName returns in with every "." replaced by "_".
func normalizeName(in string) string {
	const (
		separator   = "."
		replacement = "_"
	)
	return strings.ReplaceAll(in, separator, replacement)
}

// normalizationSkipList holds the full names of message types that are left
// untouched rather than being normalized into the fully-contained structure.
// It is currently empty.
var normalizationSkipList = []string{
	/*
		TODO: when backend supports resolving well known types, this list should be enabled.
		"google.protobuf.DoubleValue",
		"google.protobuf.FloatValue",
		"google.protobuf.Int64Value",
		"google.protobuf.UInt64Value",
		"google.protobuf.Int32Value",
		"google.protobuf.UInt32Value",
		"google.protobuf.BoolValue",
		"google.protobuf.StringValue",
		"google.protobuf.BytesValue",
	*/
}

// skipNormalization reports whether fullName exactly matches an entry of
// normalizationSkipList.
func skipNormalization(fullName string) bool {
	for i := range normalizationSkipList {
		if normalizationSkipList[i] == fullName {
			return true
		}
	}
	return false
}
